{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6tzp2bPEiK_S"
      },
      "source": [
        "##### Copyright 2023 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "E2347LPWgmcO"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0lfjAG3IiHSU"
      },
      "source": [
        "# Tutorial on Ranking in TF-Agents"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OQJdLZ636rDN"
      },
      "source": [
        "### Get Started\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/ranking_tutorial\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003e\n",
        "    View on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/agents/blob/master/docs/tutorials/ranking_tutorial.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/agents/blob/master/docs/tutorials/ranking_tutorial.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/agents/docs/tutorials/ranking_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ql6S68mZ6hMG"
      },
      "source": [
        "### Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFs2W62pqUxk"
      },
      "outputs": [],
      "source": [
        "!pip install tf-agents[reverb]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1dbfZarwmB96"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "from tf_agents.bandits.agents import ranking_agent\n",
        "from tf_agents.bandits.agents.examples.v2 import trainer\n",
        "from tf_agents.bandits.environments import ranking_environment\n",
        "from tf_agents.bandits.networks import global_and_arm_feature_network\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.bandits.policies import ranking_policy\n",
        "from tf_agents.bandits.replay_buffers import bandit_replay_buffer\n",
        "from tf_agents.drivers import dynamic_step_driver\n",
        "from tf_agents.metrics import tf_metrics\n",
        "from tf_agents.specs import bandit_spec_utils\n",
        "from tf_agents.specs import tensor_spec\n",
        "from tf_agents.trajectories import trajectory"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l-KqVvl_g9El"
      },
      "source": [
        "# Introduction"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OoHE8C13g2O1"
      },
      "source": [
        "In this tutorial, we guide you through the ranking algorithms implemented as part of the TF-Agents Bandits library. In a ranking problem, in every iteration an agent is presented with a set of items, and is tasked with ranking some or all of them to a list. This ranking decision then receives some form of feedback (maybe a user does or does not click on one or more of the selected items for example). The goal of the agent is to optimize some metric/reward with the goal of making better decisions over time."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-chVGcJVll1G"
      },
      "source": [
        "# Prerequisites"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M92k7X27lqOM"
      },
      "source": [
        "The ranking algorithms in TF-Agents belong to a special type of bandit agents that operate on \"per-arm\" bandit problems. Hence, to be able to benefit the most from this tutorial, the reader should familiarize themselves with the [bandit](https://github.com/tensorflow/agents/tree/master/docs/tutorials/bandits_tutorial.ipynb) and the [per-arm bandit](https://github.com/tensorflow/agents/tree/master/docs/tutorials/per_arm_bandits_tutorial.ipynb) tutorials."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o3gxWMIGvhNX"
      },
      "source": [
        "# The Ranking Problem and its Variants"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ek9XbdjDvlFs"
      },
      "source": [
        "For this tutorial, we will use the example of presenting items for sale to users. In every iteration, we receive a set of items and possibly a number describing how many of them we should display. We assume the number of items at hand is always greater than or equal to the number of slots to place them in. We need to fill the slots in the display to maximize the probability that the user will interact with one or more of the displayed items. The user, as well as the items, are described by *features*.\n",
        "\n",
        "If we manage to put items on display that are liked by the user, the probability of user/item interactions increases. Hence, it's a good idea to learn how user-item pairs match. But how do we know if an item is liked by the user? To this end, we introduce *Feedback Types*."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cTZ9RvYrDM2u"
      },
      "source": [
        "#Feedback Types"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QNWZmMRoDPRX"
      },
      "source": [
        "As opposed to bandit problems where the feedback signal (the reward) is directly associated with a single chosen item, in ranking we need to consider how the feedback translates to the \"goodness\" of the displayed items. In other words, we need to assign scores to all or some of the displayed items. In our library we offer two different feedback types: *vector feedback* and *cascading feedback*."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x5c71Vyrul4z"
      },
      "source": [
        "## Vector Feedback"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ekcxCX-Ru8I1"
      },
      "source": [
        "In the vector feedback type, we assume that the agent receives a scalar score for every item in the output ranking. These scalars are put together in a vector in the same ordering as the output ranking. Thus the feedback is a vector of the same size as the number of elements in the ranking.\n",
        "\n",
        "This feedback type is quite straightforward in the sense that we don't need to worry about converting feedback signals to scores. On the other hand, the responsibility of scoring items falls on the designer (aka. you): it's up to the system designer to decide what scores to give based on the item, its position, and whether it was interacted with by the user."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p9mnWzWbu3II"
      },
      "source": [
        "##Cascading Feedback"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zUk2__0CvsLz"
      },
      "source": [
        "In the cascading feedback type (the term coined by [Craswell et al., 2008](https://dl.acm.org/doi/abs/10.1145/1341531.1341545)), we assume the user looks at the displayed items in a sequential manner, starting at the top slot. As soon as the user finds an item worthy of clicking, they click and never return to the current ranked list. They don't even look at items below the item clicked. Not clicking on any item is also a possibility, this happens when none of the displayed items are worthy of clicking. In this case, the user does look at all the items.\n",
        "\n",
        "The feedback signal is composed of two elements: The index of the chosen element, and the value of the click. Then it is the agent's task to translate this information to scores. In our implementation in the bandit library, we implemented the convention that seen but unclicked items receive some low score (typically 0 or -1), the clicked item receives the click value, and the items beyond the clicked one are ignored by the agent."
      ]
    },
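    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the cascading convention concrete, the sketch below converts one `(chosen_index, chosen_value)` pair into per-slot scores and a mask of the slots the agent would train on. It is an illustration only, not the library's code: the helper `cascading_feedback_to_scores`, its `non_click_score` argument, and the encoding of a no-click as `chosen_index == num_slots` are assumptions made for this example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def cascading_feedback_to_scores(chosen_index, chosen_value, num_slots,\n",
        "                                 non_click_score=0.0):\n",
        "  # Hypothetical helper: returns (scores, mask) for a single ranking.\n",
        "  # Every slot starts with the low score given to seen-but-unclicked items.\n",
        "  scores = np.full(num_slots, non_click_score, dtype=np.float32)\n",
        "  mask = np.zeros(num_slots, dtype=np.float32)\n",
        "  if chosen_index >= num_slots:\n",
        "    # Assumed no-click encoding: the user looked at every slot.\n",
        "    mask[:] = 1.0\n",
        "  else:\n",
        "    # The clicked slot gets the click value; slots below it stay masked out.\n",
        "    scores[chosen_index] = chosen_value\n",
        "    mask[:chosen_index + 1] = 1.0\n",
        "  return scores, mask\n",
        "\n",
        "\n",
        "# The user skips slot 0, clicks slot 1, and never sees slot 2.\n",
        "example_scores, example_mask = cascading_feedback_to_scores(\n",
        "    chosen_index=1, chosen_value=1.0, num_slots=3)\n",
        "print(example_scores, example_mask)"
      ]
    },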
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ey94cmzkDUP7"
      },
      "source": [
        "# Diversity and Exploration"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7JcMNvBDDX_b"
      },
      "source": [
        "To maximize the chances of the user clicking on an item, it's not enough to just choose the highest scoring items and put them high in the ranking. For a user with a lot of different interests, they might be most interested in sports, but they also like arts and traveling. Giving all the sporty items the highest estimated scores and displaying all sporty items in the highest slots may not be optimal. The user might be in the mood for arts or traveling. Hence, it is a good idea to display a mix of the high-scoring interests. It is important to not only maximize the score of the displayed items but also make sure they form a diverse set.\n",
        "\n",
        "As with other limited-information learning problems (like bandits), it is also important to keep in mind that our decisions not only affect the immediate reward, but also the training data and future reward. If we always only display items based on their current estimated score, we might be missing out on high-scoring items that we haven't explored enough yet, and thus we are not aware of how good they are. That is, we need to incorporate exploration to our decision making process.\n",
        "\n",
        "All of the above concepts and considerations are addressed in our library. In this tutorial we walk you through the details."
      ]
    },
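    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of trading score against diversity, the sketch below greedily fills the slots and, after each pick, lowers the scores of the remaining items in proportion to their cosine similarity to the picked item. This is a simplified, deterministic stand-in written for this tutorial, not the policy implemented in the library; the function `greedy_diverse_ranking` and the toy scores and features are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def greedy_diverse_ranking(scores, item_features, num_slots,\n",
        "                           penalty_mixture=1.0):\n",
        "  # Hypothetical helper: returns the indices of the chosen items, top slot first.\n",
        "  norms = np.linalg.norm(item_features, axis=1, keepdims=True)\n",
        "  unit = item_features / np.maximum(norms, 1e-8)\n",
        "  adjusted = np.asarray(scores, dtype=np.float64).copy()\n",
        "  chosen = []\n",
        "  for _ in range(num_slots):\n",
        "    best = int(np.argmax(adjusted))\n",
        "    chosen.append(best)\n",
        "    # Penalize every item by its cosine similarity to the item just picked.\n",
        "    adjusted = adjusted - penalty_mixture * (unit @ unit[best])\n",
        "    # Items that are already placed can never be picked again.\n",
        "    adjusted[chosen] = -np.inf\n",
        "  return chosen\n",
        "\n",
        "\n",
        "# Items 0 and 1 are near-duplicates; item 2 is different but scores lower.\n",
        "toy_features = np.array([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]])\n",
        "toy_scores = np.array([1.0, 0.9, 0.6])\n",
        "print(greedy_diverse_ranking(toy_scores, toy_features, 2, penalty_mixture=0.0))\n",
        "print(greedy_diverse_ranking(toy_scores, toy_features, 2, penalty_mixture=1.0))"
      ]
    },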
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jwWucISQQSGt"
      },
      "source": [
        "# Simulating Users: Our Test Environment"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lxLWjKe1Q2Xz"
      },
      "source": [
        "Let's dive into our codebase!\n",
        "\n",
        "First we define the environment, the class responsible for randomly generating user and item features, and give feedback after decisions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sR8Id9Y7mhBK"
      },
      "outputs": [],
      "source": [
        "feedback_model = ranking_environment.FeedbackModel.CASCADING #@param[\"ranking_environment.FeedbackModel.SCORE_VECTOR\", \"ranking_environment.FeedbackModel.CASCADING\"] {type:\"raw\"}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M38wgZHqcI-M"
      },
      "source": [
        "We also need a model for the environment to decide when to *not click*. We have two ways in our library, *distance based* and *ghost actions*.\n",
        "\n",
        "\n",
        "*   In distance based, if the user features are not close enough to any of the item features, the user does not click.\n",
        "*   In the ghost actions model, we set up extra imaginary actions in the form of unit vector item features. If the user chooses one of the ghost actions, it results in a no-click.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "3b1hYbTxi6Kg"
      },
      "outputs": [],
      "source": [
        "click_type = \"ghost_actions\"  #@param[\"distance_based\", \"ghost_actions\"]\n",
        "click_model = (ranking_environment.ClickModel.DISTANCE_BASED\n",
        "               if click_type == \"distance_based\" else\n",
        "               ranking_environment.ClickModel.GHOST_ACTIONS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Ut1nnmB6_9T"
      },
      "source": [
        "We are almost ready to define the ranking environment, just a couple of preparations: we define the sampling functions for the global (user) and the item features. These features will be used by the environment to simulate user behavior: a weighted inner product of the global and item features is calculated, and the probability of the user clicking is proportional to the inner product values. The weighting of the inner product is defined by `scores_weight_matrix` below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ueVe8Gf77Lad"
      },
      "outputs": [],
      "source": [
        "global_dim = 9  #@param{ type: \"integer\"}\n",
        "item_dim   = 11  #@param{ type: \"integer\"}\n",
        "num_items  = 50 #@param{ type: \"integer\"}\n",
        "num_slots  = 3  #@param{ type: \"integer\"}\n",
        "distance_threshold = 5.0  #@param{ type: \"number\" }\n",
        "batch_size = 128   #@param{ type: \"integer\"}\n",
        "\n",
        "def global_sampling_fn():\n",
        "  return np.random.randint(-1, 1, [global_dim]).astype(np.float32)\n",
        "\n",
        "def item_sampling_fn():\n",
        "  return np.random.randint(-2, 3, [item_dim]).astype(np.float32)\n",
        "\n",
        "# Inner product with excess dimensions ignored.\n",
        "scores_weight_matrix = np.eye(11, 9, dtype=np.float32)\n",
        "\n",
        "env = ranking_environment.RankingPyEnvironment(\n",
        "    global_sampling_fn,\n",
        "    item_sampling_fn,\n",
        "    num_items=num_items,\n",
        "    num_slots=num_slots,\n",
        "    scores_weight_matrix=scores_weight_matrix,\n",
        "    feedback_model=feedback_model,\n",
        "    click_model=click_model,\n",
        "    distance_threshold=distance_threshold,\n",
        "    batch_size=batch_size)\n",
        "\n",
        "# Convert the python environment to tf environment.\n",
        "environment = tf_py_environment.TFPyEnvironment(env)"
      ]
    },
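    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The weighted inner product that drives the simulated user can be sketched in a few lines of NumPy. The `sketch_*` names and the small item count are made up for illustration, and the snippet only computes the raw scores; how the environment turns them into clicks is not reproduced here."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "sketch_rng = np.random.default_rng(0)\n",
        "sketch_global_dim, sketch_item_dim, sketch_num_items = 9, 11, 5\n",
        "\n",
        "# One user and a handful of items, drawn from the same integer ranges as above.\n",
        "sketch_user = sketch_rng.integers(\n",
        "    -1, 1, size=sketch_global_dim).astype(np.float32)\n",
        "sketch_items = sketch_rng.integers(\n",
        "    -2, 3, size=(sketch_num_items, sketch_item_dim)).astype(np.float32)\n",
        "sketch_weights = np.eye(sketch_item_dim, sketch_global_dim, dtype=np.float32)\n",
        "\n",
        "# score_i = item_i^T W user, computed for all items at once.\n",
        "sketch_scores = sketch_items @ sketch_weights @ sketch_user\n",
        "\n",
        "# With the rectangular identity as W, the two extra item dimensions are\n",
        "# dropped and the score is a plain dot product over the first nine.\n",
        "plain_scores = sketch_items[:, :sketch_global_dim] @ sketch_user\n",
        "print(sketch_scores, np.allclose(sketch_scores, plain_scores))"
      ]
    },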
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wN2aFkL_qHJe"
      },
      "source": [
        "Now let's define a few different agents that will tackle the above environment! All of the agents train a network that estimates scores of item/user pairs. The difference lies in the policy, that is, how the trained network is used to make a ranking decision. The implemented policies span from just stack ranking based on scores to taking into account diversity and exploration with the ability to tune the mixture of these aspects."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MBJgkyFC64rR"
      },
      "outputs": [],
      "source": [
        "#@title Defining the Network and Training Params\n",
        "scoring_network = (\n",
        "      global_and_arm_feature_network.create_feed_forward_common_tower_network(\n",
        "          environment.observation_spec(), (20, 10), (20, 10), (20, 10)))\n",
        "learning_rate = 0.005  #@param{ type: \"number\"}\n",
        "\n",
        "feedback_dict = {ranking_environment.FeedbackModel.CASCADING: ranking_agent.FeedbackModel.CASCADING,\n",
        "                 ranking_environment.FeedbackModel.SCORE_VECTOR: ranking_agent.FeedbackModel.SCORE_VECTOR}\n",
        "agent_feedback_model = feedback_dict[feedback_model]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ufuiBZsg5YrR"
      },
      "outputs": [],
      "source": [
        "#@title Stack Ranking Deterministically by Scores\n",
        "\n",
        "policy_type = ranking_agent.RankingPolicyType.DESCENDING_SCORES\n",
        "descending_scores_agent = ranking_agent.RankingAgent(\n",
        "    time_step_spec=environment.time_step_spec(),\n",
        "    action_spec=environment.action_spec(),\n",
        "    scoring_network=scoring_network,\n",
        "    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),\n",
        "    feedback_model=agent_feedback_model,\n",
        "    policy_type=policy_type,\n",
        "    summarize_grads_and_vars=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8ko9D3qO7gUs"
      },
      "outputs": [],
      "source": [
        "#@title Sampling Sequentially Based on Scores\n",
        "\n",
        "policy_type = ranking_agent.RankingPolicyType.NO_PENALTY\n",
        "logits_temperature = 1.0  #@param{ type: \"number\" }\n",
        "\n",
        "no_penalty_agent = ranking_agent.RankingAgent(\n",
        "    time_step_spec=environment.time_step_spec(),\n",
        "    action_spec=environment.action_spec(),\n",
        "    scoring_network=scoring_network,\n",
        "    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),\n",
        "    feedback_model=agent_feedback_model,\n",
        "    policy_type=policy_type,\n",
        "    logits_temperature=logits_temperature,\n",
        "    summarize_grads_and_vars=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Gnv5c4w094A4"
      },
      "outputs": [],
      "source": [
        "#@title Sampling Sequentally and Taking Diversity into Account\n",
        "#@markdown The balance between ranking based on scores and taking diversity into account is governed by the following \"penalty mixture\" parameter. A low positive value results in rankings that hardly mix in diversity, a higher value will enforce more diversity.\n",
        "\n",
        "policy_type = ranking_agent.RankingPolicyType.COSINE_DISTANCE\n",
        "penalty_mixture = 1.0 #@param{ type: \"number\"}\n",
        "\n",
        "cosine_distance_agent = ranking_agent.RankingAgent(\n",
        "    time_step_spec=environment.time_step_spec(),\n",
        "    action_spec=environment.action_spec(),\n",
        "    scoring_network=scoring_network,\n",
        "    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),\n",
        "    feedback_model=agent_feedback_model,\n",
        "    policy_type=policy_type,\n",
        "    logits_temperature=logits_temperature,\n",
        "    penalty_mixture_coefficient=penalty_mixture,\n",
        "    summarize_grads_and_vars=True)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZMImW7rrWn5w"
      },
      "outputs": [],
      "source": [
        "#@title Choosing the desired agent.\n",
        "agent_type = \"cosine_distance_agent\" #@param[\"cosine_distance_agent\", \"no_penalty_agent\", \"descending_scores_agent\"]\n",
        "if agent_type == \"descending_scores_agent\":\n",
        "  agent = descending_scores_agent\n",
        "elif agent_type == \"no_penalty_agent\":\n",
        "  agent = no_penalty_agent\n",
        "else:\n",
        "  agent = cosine_distance_agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SYQ2sCaz6tTX"
      },
      "source": [
        "Before we can start our training loop, there is one more thing we need to take care of, concerning the training data.\n",
        "\n",
        "The arm features presented to the policy at decision time contains all items that the policy can choose from. However, at training, we need the features of items that were selected, and for convenience, in the order of the decision output. To this end, the following function is used (copied here for clarity from [here](https://github.com/tensorflow/agents/tree/master/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vpUfQXgD7y0t"
      },
      "outputs": [],
      "source": [
        "def order_items_from_action_fn(orig_trajectory):\n",
        "  \"\"\"Puts the features of the selected items in the recommendation order.\n",
        "\n",
        "  This function is used to make sure that at training the item observation is\n",
        "  filled with features of items selected by the policy, in the order of the\n",
        "  selection. Features of unselected items are discarded.\n",
        "\n",
        "  Args:\n",
        "    orig_trajectory: The trajectory as output by the policy\n",
        "\n",
        "  Returns:\n",
        "    The modified trajectory that contains slotted item features.\n",
        "  \"\"\"\n",
        "  item_obs = orig_trajectory.observation[\n",
        "      bandit_spec_utils.PER_ARM_FEATURE_KEY]\n",
        "  action = orig_trajectory.action\n",
        "  if isinstance(\n",
        "      orig_trajectory.observation[bandit_spec_utils.PER_ARM_FEATURE_KEY],\n",
        "      tensor_spec.TensorSpec):\n",
        "    dtype = orig_trajectory.observation[\n",
        "        bandit_spec_utils.PER_ARM_FEATURE_KEY].dtype\n",
        "    shape = [\n",
        "        num_slots, orig_trajectory.observation[\n",
        "            bandit_spec_utils.PER_ARM_FEATURE_KEY].shape[-1]\n",
        "    ]\n",
        "    new_observation = {\n",
        "        bandit_spec_utils.GLOBAL_FEATURE_KEY:\n",
        "            orig_trajectory.observation[bandit_spec_utils.GLOBAL_FEATURE_KEY],\n",
        "        bandit_spec_utils.PER_ARM_FEATURE_KEY:\n",
        "            tensor_spec.TensorSpec(dtype=dtype, shape=shape)\n",
        "    }\n",
        "  else:\n",
        "    slotted_items = tf.gather(item_obs, action, batch_dims=1)\n",
        "    new_observation = {\n",
        "        bandit_spec_utils.GLOBAL_FEATURE_KEY:\n",
        "            orig_trajectory.observation[bandit_spec_utils.GLOBAL_FEATURE_KEY],\n",
        "        bandit_spec_utils.PER_ARM_FEATURE_KEY:\n",
        "            slotted_items\n",
        "    }\n",
        "  return trajectory.Trajectory(\n",
        "      step_type=orig_trajectory.step_type,\n",
        "      observation=new_observation,\n",
        "      action=(),\n",
        "      policy_info=(),\n",
        "      next_step_type=orig_trajectory.next_step_type,\n",
        "      reward=orig_trajectory.reward,\n",
        "      discount=orig_trajectory.discount)"
      ]
    },
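    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `tf.gather(item_obs, action, batch_dims=1)` call in the function above picks, for every batch element, the item rows listed in that element's action, in slot order. The NumPy sketch below reproduces the same indexing on toy data; the `sketch_*` arrays and their sizes are invented for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Toy item features with shape [batch_size=2, num_items=4, item_dim=3].\n",
        "sketch_item_obs = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)\n",
        "# One ranking per batch element, shape [batch_size=2, num_slots=2].\n",
        "sketch_action = np.array([[3, 0], [1, 2]])\n",
        "\n",
        "# Pair each batch row with its own item indices.\n",
        "sketch_batch_index = np.arange(sketch_item_obs.shape[0])[:, None]\n",
        "sketch_slotted = sketch_item_obs[sketch_batch_index, sketch_action]\n",
        "\n",
        "# Resulting shape is [batch_size, num_slots, item_dim].\n",
        "print(sketch_slotted.shape)"
      ]
    },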
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VQDWXgDUsCZ1"
      },
      "outputs": [],
      "source": [
        "#@title Defininfing Parameters to Run the Agent on the Defined Environment\n",
        "num_iterations = 400 #@param{ type: \"number\" }\n",
        "steps_per_loop = 2   #@param{ type: \"integer\" }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Na2ZHarNVS0-"
      },
      "source": [
        "As in the bandit tutorials, we define the replay buffer that will feed the agent the samples to train on. Then, we use the driver to put everything together: The environment provides features, the policy chooses rankings, and samples are collected to be trained on."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qt6ifI5AYWfu"
      },
      "outputs": [],
      "source": [
        "replay_buffer = bandit_replay_buffer.BanditReplayBuffer(\n",
        "      data_spec=order_items_from_action_fn(agent.policy.trajectory_spec),\n",
        "      batch_size=batch_size,\n",
        "      max_length=steps_per_loop)\n",
        "\n",
        "if feedback_model == ranking_environment.FeedbackModel.SCORE_VECTOR:\n",
        "  reward_metric = tf_metrics.AverageReturnMetric(\n",
        "      batch_size=environment.batch_size,\n",
        "      buffer_size=200)\n",
        "else:\n",
        "  reward_metric = tf_metrics.AverageReturnMultiMetric(\n",
        "        reward_spec=environment.reward_spec(),\n",
        "        batch_size=environment.batch_size,\n",
        "        buffer_size=200)\n",
        "\n",
        "add_batch_fn = lambda data: replay_buffer.add_batch(\n",
        "        order_items_from_action_fn(data))\n",
        "\n",
        "observers = [add_batch_fn, reward_metric]\n",
        "\n",
        "driver = dynamic_step_driver.DynamicStepDriver(\n",
        "    env=environment,\n",
        "    policy=agent.collect_policy,\n",
        "    num_steps=steps_per_loop * batch_size,\n",
        "    observers=observers)\n",
        "\n",
        "reward_values = []\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "  driver.run()\n",
        "  loss_info = agent.train(replay_buffer.gather_all())\n",
        "  replay_buffer.clear()\n",
        "  if feedback_model == ranking_environment.FeedbackModel.SCORE_VECTOR:\n",
        "    reward_values.append(reward_metric.result())\n",
        "  else:\n",
        "    reward_values.append(reward_metric.result())\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gwy7cQP3JrU0"
      },
      "source": [
        "Let's plot the reward!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eZPOXzfyy5Sh"
      },
      "outputs": [],
      "source": [
        "if feedback_model == ranking_environment.FeedbackModel.SCORE_VECTOR:\n",
        "  reward = reward_values\n",
        "else:\n",
        "  reward = [r[\"chosen_value\"] for r in reward_values]\n",
        "plt.plot(reward)\n",
        "plt.ylabel('Average Return')\n",
        "plt.xlabel('Number of Iterations')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gvbm_qCGJy0D"
      },
      "source": [
        "# What's Next?\n",
        "\n",
        "This tutorial has lots of tunable parameters, including the policy/agent to use, some properties of the environment, and even the feedback model. Feel free to experiment with those parameters!\n",
        "\n",
        "There is also a ready-to-run example for ranking in `tf_agents/bandits/agents/examples/v2/train_eval_ranking.py`"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "ranking_tutorial.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
